Molecular Dynamics in Dex

import plot

This is more or less a port of JAX MD to Dex, written to see how molecular dynamics looks in Dex. For now, the two implementations are structured similarly, though the details differ.

Math

def truncate (x: Float): Float =
  case x < 0.0 of
    True -> -floor(-x)
    False -> floor x
-- A mod that matches np.mod and python mod.
def pmod (x: Float) (y: Float): Float = x - floor(x / y) * y
-- A mod that matches np.fmod and c fmod.
def fmod (x: Float) (y: Float): Float = x - truncate (x / y) * y
def Vec (dim:Type) [Ix dim] : Type = dim=>Float
def sq_norm {dim} (r: Vec dim) : Float = sum $ for i. r.i * r.i
def norm {dim} (r: Vec dim) : Float = sqrt $ sq_norm r
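
The two mod conventions above differ only when the quotient is negative. As a cross-check outside Dex, here is a small Python sketch (Python is used purely for illustration; the helper names simply mirror the Dex definitions) comparing them against Python's own `%` and `math.fmod`:

```python
import math

def truncate(x: float) -> float:
    # Round toward zero, like the Dex `truncate` above.
    return -math.floor(-x) if x < 0.0 else math.floor(x)

def pmod(x: float, y: float) -> float:
    # Floored modulo: the result takes the sign of y (Python `%`, np.mod).
    return x - math.floor(x / y) * y

def fmod(x: float, y: float) -> float:
    # Truncated modulo: the result takes the sign of x (C fmod, np.fmod).
    return x - truncate(x / y) * y

print(pmod(-1.0, 3.0))  # 2.0
print(fmod(-1.0, 3.0))  # -1.0
```

The floored variant is the one that maps any coordinate into `[0, y)`, which is why the periodic helpers later in this document use `pmod`.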

Useful Quantities

This computes a size for a box of the given number of dimensions, such that the given number of particles will fill it with the given density.

def box_size_at_number_density (particle_count: Int) (density: Float) (dim: Int) : Float = pow ((i_to_f particle_count) / density) (1.0 / (i_to_f dim))
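
As a worked check of the formula, the side length L satisfies N / L^dim = density, so L = (N / density)^(1/dim). The Python sketch here (illustrative only, not part of the Dex program) evaluates it for the 500-particle, density-1.2, two-dimensional system used in the Example section:

```python
def box_size_at_number_density(particle_count: int, density: float, dim: int) -> float:
    # Side length L of a dim-dimensional cube with particle_count / L**dim == density.
    return (particle_count / density) ** (1.0 / dim)

side = box_size_at_number_density(500, 1.2, 2)
print(round(side, 4))  # 20.4124
```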

Spaces

def Displacement (dim:Type) : Type = Vec dim -> Vec dim -> Vec dim
def Shift (dim:Type) : Type = Vec dim -> Vec dim -> Vec dim
def free_displacement {dim} : Displacement dim = \r_1 r_2. r_1 - r_2
def free_shift {dim} : Shift dim = \r dr. r + dr
def periodic_wrap {dim} (box: Float) (dr: dim=>Float) : (dim=>Float) = for i. (pmod (dr.i + box / 2.0) box) - box / 2.0
def periodic_displacement {dim} (box:Float) : Displacement dim = \r_1 r_2. periodic_wrap box (r_1 - r_2)
def periodic_shift {dim} (box:Float) : Shift dim = \r dr. for i. pmod (r.i + dr.i) box
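
To make the periodic case concrete: `periodic_wrap` maps each component of a raw difference into `[-box/2, box/2)`, so the displacement always points along the shortest path on the torus. A minimal Python sketch of the same arithmetic (illustrative only; plain lists stand in for Dex tables, and the box size 10 is an arbitrary example value):

```python
import math

def pmod(x, y):
    # Floored modulo, matching the Dex `pmod`.
    return x - math.floor(x / y) * y

def periodic_wrap(box, dr):
    # Map each component into [-box/2, box/2).
    return [pmod(d + box / 2.0, box) - box / 2.0 for d in dr]

def periodic_displacement(box, r1, r2):
    return periodic_wrap(box, [a - b for a, b in zip(r1, r2)])

def periodic_shift(box, r, dr):
    # Move a point and wrap it back into [0, box).
    return [pmod(a + d, box) for a, d in zip(r, dr)]

print(periodic_displacement(10.0, [9.5, 5.0], [0.5, 5.0]))  # [-1.0, 0.0]
print(periodic_shift(10.0, [9.5, 5.0], [1.0, 0.0]))         # [0.5, 5.0]
```

The two points are 9 apart in raw coordinates but only 1 apart across the periodic boundary, which is what the wrapped displacement reports.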

Energy Functions

def Energy (n: Type) (dim: Type) : Type = (n=>Vec dim) -> Float

The "soft-sphere" energy, as a function of the distance r between two particles (the norm of their displacement).

def soft_sphere (ε: Float) (α: Float) (σ: Float) (r: Float) : Float =
  case r < σ of
    True -> ε / α * pow (1 - r / σ) α
    False -> 0.0
def harmonic_soft_sphere (sigma:Float) (r:Float) : Float = soft_sphere 1.0 2.0 sigma r
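
In formula form, the soft-sphere energy is ε/α · (1 − r/σ)^α for r < σ and zero otherwise, so the harmonic case (ε = 1, α = 2) is a half-parabola that vanishes at the interaction range σ. A Python sketch for spot-checking values (illustrative only; the sample distances are arbitrary):

```python
def soft_sphere(eps, alpha, sigma, r):
    # Repulsive only inside the interaction range sigma; zero outside.
    return eps / alpha * (1.0 - r / sigma) ** alpha if r < sigma else 0.0

def harmonic_soft_sphere(sigma, r):
    return soft_sphere(1.0, 2.0, sigma, r)

print(harmonic_soft_sphere(1.0, 0.5))  # 0.125
print(harmonic_soft_sphere(1.0, 1.5))  # 0.0
```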

Here we have a naive pairwise energy function constructor, that promotes a two-particle energy to a whole-system energy by just applying it to every pair of distinct particles. We start with this to have something to test with.

def pair_energy {dim n} (energy: Float -> Float) (displacement: Displacement dim) (r: n=>Vec dim) : Float =
  sum $ for i. sum for j:(..<i).
    energy $ norm $ displacement r.i r.(inject _ j)
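
The inner index range `..<i` is what restricts the sum to each unordered pair exactly once. A Python sketch of the same double loop (illustrative only; the three sample positions and the inline `harmonic` and `free` helpers are made up for the example):

```python
import math

def pair_energy(energy, displacement, positions):
    # Sum energy(|displacement(r_i, r_j)|) over unordered pairs with j < i.
    total = 0.0
    for i in range(len(positions)):
        for j in range(i):
            d = displacement(positions[i], positions[j])
            total += energy(math.sqrt(sum(c * c for c in d)))
    return total

free = lambda a, b: [x - y for x, y in zip(a, b)]
harmonic = lambda r: 0.5 * (1.0 - r) ** 2 if r < 1.0 else 0.0

# Only the first two particles overlap (distance 0.5), so one pair contributes.
print(pair_energy(harmonic, free, [[0.0, 0.0], [0.5, 0.0], [3.0, 0.0]]))  # 0.125
```

The cost is quadratic in the number of particles, which is the motivation for the neighbor-list version developed later.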

Minimization Functions

Now we develop the FIRE Descent algorithm.

data FireDescentState n:Type dim:Type [Ix n, Ix dim] =
  MkFireDescentState { R: (n=>Vec dim) & V: (n=>Vec dim) & F: (n=>Vec dim) & dt: Float & alpha: Float & n_pos: Int }

def get_position {dim n} (state: FireDescentState n dim) : (n=>Vec dim) =
  (MkFireDescentState {R, V, F, dt, alpha, n_pos}) = state
  R

def fire_descent_init {dim n} (dt: Float) (alpha: Float) (energy: Energy n dim) (r: n=>Vec dim) : FireDescentState n dim =
  force = \rp. -(grad energy) rp
  V = for i:n j:dim. 0.0
  F = force r
  n_pos: Int = 0
  MkFireDescentState {R=r, V, F, dt, alpha, n_pos}

def fire_descent_step {dim n} (shift: Shift dim) (energy: Energy n dim) (state: FireDescentState n dim) : FireDescentState n dim =
  -- Constants that parameterize the FIRE algorithm.
  -- TODO: Thread these constants through somehow.
  -- dougalm@ is there a canonical way to do this?
  dt_start = 0.1
  dt_max = 0.4
  n_min = 5
  f_inc = 1.1
  f_dec = 0.5
  f_alpha = 0.99
  alpha_start = 0.1
  ε = 0.000000001
  force = \r. -(grad energy) r
  -- FIRE algorithm.
  (MkFireDescentState {R, V, F, dt, alpha, n_pos}) = state
  -- Do a Velocity-Verlet step.
  R = for i. shift R.i (V.i *. dt + F.i *. pow dt 2)
  F_old = F
  F = force R
  V = V + dt * 0.5 .* (F_old + F)
  -- Rescale the velocity.
  F_norm = sqrt $ sum for (i, j). pow F.i.j 2
  V_norm = sqrt $ sum for (i, j). pow V.i.j 2
  V = V + alpha .* (F *. V_norm / (F_norm + ε) - V)
  -- Check whether the force is aligned with the velocity.
  FdotV = sum for (i, j). F.i.j * V.i.j
  -- Decide whether to increase the speed of the simulation or reduce it.
  (n_pos, dt, alpha) = if FdotV >= 0.0
    then
      dt' = if n_pos >= n_min
        then (min (dt * f_inc) dt_max)
        else dt
      alpha' = if n_pos >= n_min
        then (alpha * f_alpha)
        else alpha
      (n_pos + 1, dt', alpha')
    else (0, dt * f_dec, alpha_start)
  V = if FdotV >= 0.0
    then V
    else zero
  MkFireDescentState { R, V, F, dt, alpha, n_pos }
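
The part of the step that gives FIRE its character is the final branch on the sign of F·V: keep accelerating while the velocity still points downhill, and reset when it does not. A Python sketch isolating just that bookkeeping (illustrative only; `fire_adapt` and its `keep_velocity` flag are hypothetical names, and the defaults copy the constants hard-coded in `fire_descent_step`):

```python
def fire_adapt(n_pos, dt, alpha, f_dot_v,
               dt_max=0.4, n_min=5, f_inc=1.1, f_dec=0.5,
               f_alpha=0.99, alpha_start=0.1):
    # Returns (n_pos, dt, alpha, keep_velocity), following the branch
    # at the end of fire_descent_step.
    if f_dot_v >= 0.0:
        # Still heading downhill: after n_min good steps in a row,
        # grow the time step (capped at dt_max) and shrink the mixing factor.
        if n_pos >= n_min:
            dt, alpha = min(dt * f_inc, dt_max), alpha * f_alpha
        return n_pos + 1, dt, alpha, True
    # Heading uphill: halve the time step, reset alpha, and zero the velocity.
    return 0, dt * f_dec, alpha_start, False

print(fire_adapt(3, 0.1, 0.1, 1.0))   # (4, 0.1, 0.1, True)
print(fire_adapt(3, 0.1, 0.1, -1.0))  # (0, 0.05, 0.1, False)
```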

Drawing

import png
import diagram

Now a tool to draw a two-dimensional system, where each particle is a disk of given size.

TwoDimensions = Fin 2
def draw_system {n} radius (r: n=>Vec TwoDimensions) : Diagram =
  disks = concat_diagrams for i.
    circle radius |> move_xy (r.i.(0 @ TwoDimensions), r.i.(1 @ TwoDimensions))
  disks

Example

Initialize a system randomly.

N_small = 500
d = 2
L_small = box_size_at_number_density (n_to_i N_small) 1.2 (n_to_i d)
-- We will simulate in a box of this side length
L_small
20.412415
R_init_small = rand_mat N_small d (\k. L_small * rand k) (new_key 0)

The initial state of our random system

%time :html render_svg (draw_system 0.5 R_init_small) ((0.0, 0.0), (L_small, L_small))
Compile time: 310.405 ms Run time: 20.864 ms

Define the energy function. Note the periodic_displacement, which means our system will be evolving on a torus.

def energy {n d} (pos: n=>d=>Float) : Float = pair_energy (harmonic_soft_sphere 1.0) (periodic_displacement L_small) pos

Here's the initial energy we compute for our system.

:t energy R_init_small
Float32
energy R_init_small
74.690056

Initialize a simulation

state_small = fire_descent_init 0.1 0.1 energy R_init_small

and test one step of minimization. The energy decreases from the initial, as expected:

energy $ get_position $ fire_descent_step free_shift energy state_small
71.78407

Now we can test that our code basically works by running 100 steps of minimization.

%time
(state_small', energies) = scan state_small \i:(Fin 100) s.
  s' = fire_descent_step (periodic_shift L_small) energy s
  (s', energy $ get_position s')
Compile time: 815.015 ms Run time: 1.498 s

Here's how the energy decreases over time.

%time :html show_plot $ y_plot energies
Compile time: 728.691 ms Run time: 4.085 ms

Here's what the system looks like after minimization.

%time :html render_svg (draw_system 0.5 (get_position state_small')) ((0.0, 0.0), (L_small, L_small))
Compile time: 308.200 ms Run time: 18.507 ms

Neighbors optimization

The above pair_energy function will compute the influence of every atom on every other atom, regardless of how far apart they are.

To simulate more efficiently, we'd like to approximate the pairwise energy with one that only includes contributions from pairs of atoms that are close enough to each other to matter, and neglects all more distant pairs.

This is a two-step operation:

  • Break the simulation volume into a grid of cells, and do a linear pass over the atoms to group them by which cell each is in.
  • Traverse every pair of adjacent cells (counting each cell as adjacent to itself) and compute energy terms only for pairs of atoms drawn from those two cells.
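
The first step is a single linear pass. A Python sketch of the grouping (illustrative only; a dictionary of lists stands in for the fixed-size cell table built later in Dex, and `build_cells` is a hypothetical name):

```python
from collections import defaultdict

def build_cells(positions, cell_size):
    # One pass over the atoms, grouping atom indices by integer cell coordinates.
    cells = defaultdict(list)
    for idx, pos in enumerate(positions):
        cells[tuple(int(c // cell_size) for c in pos)].append(idx)
    return cells

cells = build_cells([[0.2, 0.3], [0.8, 0.1], [2.5, 1.5]], 1.0)
print(dict(cells))  # {(0, 0): [0, 1], (2, 1): [2]}
```

Choosing the cell size to be at least the interaction range is what guarantees that interacting atoms always sit in the same cell or in adjacent ones.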

Bounded lists

We start with an abstraction of an incrementally growable list. To get O(1) insertion at the end, we (currently) have to give an upper bound for the list's size.

-- TODO Can we encapsulate this BoundedList type as a `data` and still
-- define in-place mutation operations on it?
def BoundedList n [Ix n] a = (n & (n => a))
def unsafe_next_index {n} [Ix n] (ix:n) : n = unsafe_from_ordinal n $ ordinal ix + 1
def empty_bounded_list {n a} [Ix n] (dummy_val: a) : BoundedList n a = (unsafe_from_ordinal _ 0, for _. dummy_val)
-- The point of a `BoundedList` is O(1) push by in-place mutation
def bd_push {h n a} [Ix n] (lst_ref: Ref h (BoundedList n a)) (x: a) : {State h} Unit =
  sz_ref = fst_ref lst_ref
  sz = get sz_ref
  buf_ref = snd_ref lst_ref
  if ordinal sz < size n
    then
      buf_ref!sz := x
      sz_ref := unsafe_next_index sz
    else
      todo  -- throw ()

-- Once we're done pushing, we can compact a `BoundedList` into a standard `List`.
def as_list {n a} (lst: BoundedList n a) : List a =
  (lim, buf) = lst
  n_result = ordinal lim
  AsList _ $ view i:(Fin n_result). buf.(unsafe_from_ordinal _ $ ordinal i)
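
The idea behind the bounded list is a buffer allocated once at its maximum size plus a fill counter, so a push is one write and one increment. A Python sketch of the same idea (illustrative only; raising an error on overflow is an assumption of this sketch, since the Dex `bd_push` leaves that branch as a `todo`):

```python
class BoundedList:
    """Fixed-capacity list: O(1) push into a preallocated buffer."""

    def __init__(self, capacity, dummy):
        self.buf = [dummy] * capacity  # allocated once, never resized
        self.size = 0                  # number of slots actually filled

    def push(self, x):
        if self.size >= len(self.buf):
            # Assumption for this sketch: signal overflow with an exception.
            raise OverflowError("BoundedList is full")
        self.buf[self.size] = x
        self.size += 1

    def as_list(self):
        # Compact to the filled prefix, dropping the dummy padding.
        return self.buf[:self.size]

lst = BoundedList(4, dummy=0)
lst.push(33)
lst.push(237)
print(lst.as_list())  # [33, 237]
```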

Cell list

We define a single index for the whole grid.

def GridIx dim grid_size [Ix dim] = dim => (Fin grid_size)

A cell list is now just a table that maps each cell in the grid to a BoundedList of the indices of the atoms that appear in that cell.

def CellTable dim grid_size bucket_size atom_ix [Ix dim] = GridIx dim grid_size => BoundedList (Fin bucket_size) atom_ix
-- Compute the cell an atom should go into.
def target_cell {dim grid_size} (cell_size: Float) (atom: Vec dim) : GridIx dim grid_size = for dim. from_ordinal _ $ f_to_n $ atom.dim / cell_size
-- A traversal of an atom array together with target cell.
-- We abstract this because we will use it twice.
def traverse_cells {dim atom_ix eff} [Ix atom_ix] grid_size cell_size (atoms: atom_ix => Vec dim) (action: atom_ix -> (GridIx dim grid_size) -> {|eff} Unit) : {|eff} Unit =
  for_ ix.
    cell = target_cell cell_size atoms.ix
    action ix cell

Here is the actual cell table computation:

def cell_table {dim atom_ix} [Ix atom_ix] grid_size bucket_size cell_size (atoms: atom_ix => Vec dim) : (CellTable dim grid_size bucket_size atom_ix) =
  yield_state (for _. empty_bounded_list $ unsafe_from_ordinal _ 0) \ref.
    traverse_cells grid_size cell_size atoms \ix cell.
      bd_push ref!cell ix
-- This is a helper for computing the bucket size. Right now we end
-- up traversing the input twice (once to compute a size and once to
-- actually build the cell table); this is working around limitations
-- in Dex's support for mutable lists of statically unknown length.
def cell_table_bucket_size {dim atom_ix} [Ix atom_ix] grid_size cell_size (atoms: atom_ix => Vec dim) : Nat =
  -- TODO Use yield_accum here
  bucket_sizes = yield_state (for _:(GridIx dim grid_size). 0) \ref.
    traverse_cells grid_size cell_size atoms \ix cell.
      ref!cell := get ref!cell + 1
  maximum bucket_sizes

Let's test it out on our "small" system:

def grid_params L desired_cell_size : _ =
  cells_per_side = floor(L / desired_cell_size)
  cell_size = L / cells_per_side
  grid_size = unsafe_i_to_n $ f_to_i cells_per_side
  cell_size, grid_size
desired_cell_size = 1.0 -- unit interaction range
cell_size, grid_size = grid_params L_small desired_cell_size
bucket_size = 10
%time tbl = cell_table grid_size bucket_size cell_size $ get_position state_small'
Compile time: 132.819 ms Run time: 36.200 us

We have a table of cells with atoms in them

:t tbl
(((Fin 2) => Fin 20) => (Fin 10 & ((Fin 10) => Fin 500)))

And here are the atoms in the 0th cell.

as_list tbl.(unsafe_from_ordinal _ 0)
(AsList 2 [(33@(Fin 500)), (237@(Fin 500))])

Now let's compute pairs of neighbors from our cell list

We'll specialize to two dimensions for now, but broadening to more is not difficult.

cell_neighbors_2d = [[-1, -1], [-1, 0], [-1, 1], [0, -1], [0, 0], [0, 1], [1, -1], [1, 0], [1, 1]]
:t cell_neighbors_2d
((Fin 9) => (Fin 2) => Int32)
-- Toroidal index offsetting
def torus_offset {n} (ix: (Fin n)) (offset: Int) : (Fin n) = unsafe_from_ordinal _ $ unsafe_i_to_n $ mod (n_to_i (ordinal ix) + offset) (n_to_i n)
-- A traversal of pairs of atoms from adjacent or equal cells in a
-- cell table.
def traverse_pairs {grid_size bucket_size atom_ix eff} (tbl: CellTable TwoDimensions grid_size bucket_size atom_ix) (atoms: atom_ix => Vec TwoDimensions) (action: atom_ix -> atom_ix -> {|eff} Unit) : {|eff} Unit =
  for_ cell_ix : (GridIx TwoDimensions grid_size).
    for_ nb.
      displacement = cell_neighbors_2d.nb
      neighbor_ix = for dim. torus_offset cell_ix.dim displacement.dim
      (AsList sz_atoms_1 atoms_1) = as_list tbl.cell_ix
      for_ atom1_ix : (Fin sz_atoms_1).
        atom1 = atoms_1.atom1_ix
        (AsList sz_atoms_2 atoms_2) = as_list tbl.neighbor_ix
        for_ atom2_ix : (Fin sz_atoms_2).
          atom2 = atoms_2.atom2_ix
          action atom1 atom2

The neighbor list computation. The point of the exercise is that this is not O(#atoms^2), but rather O(#cells) * 9 * O(#atoms per cell^2), because it only considers atoms from adjacent cells as potential neighbors.

def neighbor_list {grid_size bucket_size atom_ix} neighbor_list_size (tbl: CellTable TwoDimensions grid_size bucket_size atom_ix) (is_neighbor: atom_ix -> atom_ix -> Bool) (atoms: atom_ix => Vec TwoDimensions) : BoundedList (Fin neighbor_list_size) (atom_ix & atom_ix) =
  yield_state (empty_bounded_list $ (from_ordinal _ 0, from_ordinal _ 0)) \ref.
    traverse_pairs tbl atoms (\atom1 atom2. if is_neighbor atom1 atom2 then bd_push ref (atom1, atom2))

-- Also a helper to pre-compute how large a buffer to allocate for the
-- neighbor list, by traversing the input an extra time.
def count_neighbor_list_size {grid_size bucket_size atom_ix} (tbl: CellTable TwoDimensions grid_size bucket_size atom_ix) (is_neighbor: atom_ix -> atom_ix -> Bool) (atoms: atom_ix => Vec TwoDimensions) : Nat =
  yield_accum (AddMonoid Nat) \ref.
    traverse_pairs tbl atoms (\atom1 atom2. if is_neighbor atom1 atom2 then ref += 1)

def periodic_near {atom_ix} desired_cell_size L (atoms: atom_ix => Vec TwoDimensions) (atom1:atom_ix) (atom2:atom_ix) : Bool =
  dist = norm $ periodic_displacement L atoms.atom1 atoms.atom2
  dist < desired_cell_size
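
For comparison, here is the same traversal as a compact Python sketch (illustrative only, not the Dex implementation). It assumes a square box whose side is an integer multiple of the cell size, at least three cells per side, and a cell size no smaller than the cutoff. Unlike the Dex `neighbor_list`, which records ordered pairs and leaves the filtering to the energy function, this sketch keeps each unordered pair once with i < j:

```python
import math
from collections import defaultdict

def neighbor_pairs(positions, box, cell_size, cutoff):
    # Assumes box / cell_size is an integer >= 3 and cell_size >= cutoff.
    n_cells = int(round(box / cell_size))
    cells = defaultdict(list)
    for idx, (x, y) in enumerate(positions):
        cells[(int(x // cell_size) % n_cells, int(y // cell_size) % n_cells)].append(idx)

    def dist(a, b):
        # Minimum-image distance on the torus.
        d = [(p - q + box / 2.0) % box - box / 2.0 for p, q in zip(a, b)]
        return math.hypot(*d)

    pairs = set()
    for (cx, cy), atoms in list(cells.items()):
        for dx in (-1, 0, 1):
            for dy in (-1, 0, 1):
                # Toroidal offset: Python's % is non-negative for a positive modulus.
                for j in cells.get(((cx + dx) % n_cells, (cy + dy) % n_cells), []):
                    for i in atoms:
                        if i < j and dist(positions[i], positions[j]) < cutoff:
                            pairs.add((i, j))
    return sorted(pairs)

pts = [[0.1, 0.1], [3.9, 0.1], [2.0, 2.0], [2.4, 2.3]]
print(neighbor_pairs(pts, box=4.0, cell_size=1.0, cutoff=1.0))  # [(0, 1), (2, 3)]
```

Atoms 0 and 1 sit on opposite edges of the box and are found only because the cell offsets wrap around, which is the role `torus_offset` plays in the Dex code.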

Let's test this one out on our "small" system.

%time res = (neighbor_list 4000 tbl (periodic_near 1.0 L_small $ get_position state_small') $ get_position state_small')
Compile time: 174.558 ms Run time: 219.700 us

In that configuration, we find this many pairs of neighbors:

(AsList k _) = as_list res
k
3090

Now that we have the concept of neighbor lists, we can define a variant of pair_energy that only considers atoms that the neighbor list says are close.

def pair_energy_nl {dim n} (energy: Float -> Float) (displacement: Displacement dim) (r: n=>Vec dim) (neighbors: List (n & n)) : Float =
  (AsList k nbrs) = neighbors
  sum for i.
    a1ix, a2ix = nbrs.i
    case (ordinal a1ix) < (ordinal a2ix) of
      True -> energy $ norm $ displacement r.a1ix r.a2ix
      False -> 0.0
def energy_nl {n d} L (neighbors: List (n & n)) (pos: n=>d=>Float) : Float = pair_energy_nl (harmonic_soft_sphere 1.0) (periodic_displacement L) pos neighbors

And here's the check that it computes near-identical results to our original, fully pairwise energy function.

energy_nl L_small (as_list res) (get_position state_small')
1.230908
energy (get_position state_small')
1.230907
-- Package the above up into a function that just computes the
-- neighbor list from an array of atoms.
def just_neighbor_list {n} desired_cell_size L (atoms:n=>(Fin 2)=>Float) : (List (n & n)) =
  cell_size, grid_size = grid_params L desired_cell_size
  bucket_size = cell_table_bucket_size grid_size cell_size atoms
  tbl = cell_table grid_size bucket_size cell_size atoms
  is_neighbor = periodic_near desired_cell_size L atoms
  neighbor_list_sz = count_neighbor_list_size tbl is_neighbor atoms
  res = neighbor_list neighbor_list_sz tbl is_neighbor atoms
  as_list res

The neighbor list energy function works as an argument to our FIRE Descent algorithm (provided the neighbor list uses the same atoms).

state_nl =
  energy_func = (energy_nl L_small $ just_neighbor_list 1.0 L_small R_init_small)
  fire_descent_init 0.1 0.1 energy_func R_init_small
energy R_init_small
74.690056
energy $ get_position $ fire_descent_step free_shift energy state_small
71.78407
-- A helper for short-circuiting `any` computation
def fast_any {n eff} [Ix n] (f: n -> {|eff} Bool) : {|eff} Bool =
  iter \ct.
    if ct >= size n
      then Done False
      else if f (ct @ n)
        then Done True
        else Continue

And now that this basically works, we can package the whole thing up as a simulation. We have another trick here: we compute the neighbor list with an extra "halo", treating atoms as neighbors if they are within distance 1 + halo of each other, rather than just the interaction range 1. This way, we only have to recompute the neighbor list when some atom moves more than halo/2 away from where it was when the neighbor list was computed; until then, the list remains a sound approximation.

-- TODO(Issue 1133) Can't use scan with a body that has an effect?
def simulate {atom_ix} (displacement : Displacement TwoDimensions) halo_size L time [Ix time] (state : FireDescentState atom_ix TwoDimensions) : {IO} (FireDescentState atom_ix TwoDimensions & time => Float) =
  with_state (get_position state) \saved_atoms_ref.
    nbrs = just_neighbor_list (1.0 + halo_size) L $ get_position state
    (AsList k _) = nbrs
    print $ show k <> " initial neighbor list size"
    with_state nbrs \saved_neighbors_ref.
      swap $ run_state state \s_ref. for i.
        s = get s_ref
        new_atoms = get_position s
        rebuild = fast_any \i. 2 * norm (displacement (get saved_atoms_ref!i) new_atoms.i) > halo_size
        if rebuild then
          saved_atoms_ref := new_atoms
          nbrs = just_neighbor_list (1.0 + halo_size) L new_atoms
          saved_neighbors_ref := nbrs
          (AsList k _) = nbrs
          print $ show k <> " new neighbor list size"
        nbrs = get saved_neighbors_ref
        s' = fire_descent_step (periodic_shift L) (energy_nl L nbrs) s
        s_ref := s'
        energy_nl L nbrs $ get_position s'
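
The halo criterion rests on a simple bound: if no atom has moved more than halo/2 since the list was built, then no pair's separation has changed by more than halo, so every pair now within the interaction range was within 1 + halo at build time and is already in the list. A Python sketch of the rebuild test alone (illustrative only; `needs_rebuild` is a hypothetical name and the sample positions are made up):

```python
import math

def needs_rebuild(saved, current, displacement, halo):
    # Rebuild once any atom has moved more than halo / 2 since the last build.
    return any(2.0 * math.hypot(*displacement(s, c)) > halo
               for s, c in zip(saved, current))

free = lambda a, b: [x - y for x, y in zip(a, b)]
saved = [[0.0, 0.0], [1.0, 1.0]]

# Largest move is 0.2, below halo / 2 = 0.25: the old list is still valid.
print(needs_rebuild(saved, [[0.1, 0.0], [1.0, 1.2]], free, halo=0.5))  # False
# One atom moved 0.3, above halo / 2: time to rebuild.
print(needs_rebuild(saved, [[0.3, 0.0], [1.0, 1.0]], free, halo=0.5))  # True
```

A larger halo means fewer rebuilds but a longer neighbor list to traverse on every step, so the value 0.5 used in the runs here is one point on that trade-off rather than a tuned optimum.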

Let's check that it works on our test system from before

%time (state_nl', energies_nl) = unsafe_io do simulate (periodic_displacement L_small) 0.5 L_small (Fin 100) state_small
4568 initial neighbor list size
4614 new neighbor list size
4564 new neighbor list size
4376 new neighbor list size
4346 new neighbor list size
4266 new neighbor list size
4178 new neighbor list size
4140 new neighbor list size
4100 new neighbor list size
4028 new neighbor list size
4006 new neighbor list size
3922 new neighbor list size
3868 new neighbor list size
3810 new neighbor list size
3762 new neighbor list size
3720 new neighbor list size
3656 new neighbor list size
3640 new neighbor list size
3602 new neighbor list size
3572 new neighbor list size
3554 new neighbor list size
Compile time: 1.901 s Run time: 28.524 ms
%time :html show_plot $ y_plot energies_nl
Compile time: 724.445 ms Run time: 4.786 ms
%time :html render_svg (draw_system 0.5 (get_position state_nl')) ((0.0, 0.0), (L_small, L_small))
Compile time: 305.757 ms Run time: 18.192 ms

But of course the point of the exercise is that this now scales up to larger systems because it avoids the quadratic energy computation.

N_large = if not (dex_test_mode ()) then 50000 else 500
L_large = box_size_at_number_density (n_to_i N_large) 1.2 (n_to_i d)
L_large
204.12415
R_init_large = rand_mat N_large d (\k. L_large * rand k) (new_key 0)

Initial state (we render the atoms smaller now so they don't over-plot too badly).

%time :html render_svg (draw_system 0.2 R_init_large) ((0.0, 0.0), (L_large, L_large))
Compile time: 717.270 ms Run time: 1.975 s
state_large =
  energy_func = (energy_nl L_large $ just_neighbor_list 1.0 L_large R_init_large)
  fire_descent_init 0.1 0.1 energy_func R_init_large
%time (state_large_nl', energies_large_nl) = unsafe_io do simulate (periodic_displacement L_large) 0.5 L_large (Fin 100) state_large
474778 initial neighbor list size
474090 new neighbor list size
471628 new neighbor list size
465540 new neighbor list size
453724 new neighbor list size
442298 new neighbor list size
434296 new neighbor list size
428432 new neighbor list size
421936 new neighbor list size
418256 new neighbor list size
414658 new neighbor list size
411276 new neighbor list size
404294 new neighbor list size
398566 new neighbor list size
393964 new neighbor list size
390034 new neighbor list size
386788 new neighbor list size
382868 new neighbor list size
379508 new neighbor list size
376434 new neighbor list size
374030 new neighbor list size
371268 new neighbor list size
368936 new neighbor list size
366960 new neighbor list size
365348 new neighbor list size
363868 new neighbor list size
362176 new neighbor list size
360944 new neighbor list size
360136 new neighbor list size
359590 new neighbor list size
358944 new neighbor list size
358036 new neighbor list size
357444 new neighbor list size
357094 new neighbor list size
Compile time: 2.202 s Run time: 7.192 s

Energy decrease

%time :html show_plot $ y_plot energies_large_nl
Compile time: 719.589 ms Run time: 3.887 ms

System state after minimization.

%time :html render_svg (draw_system 0.2 (get_position state_large_nl')) ((0.0, 0.0), (L_large, L_large))
Compile time: 485.198 ms Run time: 1.954 s